{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mmcv\n",
    "\n",
    "# NOTE(review): absolute, machine-specific path; the notebook only runs where this file exists.\n",
    "file_path = '/home/cihangxie/shaoyuan/BEV-Attack/test.pkl'\n",
    "\n",
    "# NOTE(review): a .pkl file is presumably unpickled here, so only load files you trust.\n",
    "data = mmcv.load(file_path)\n",
    "\n",
    "# Pull the three entries used by the rest of the notebook out of the loaded dict.\n",
    "outputs, gt_bboxes_3d, gt_labels_3d = (\n",
    "    data[key] for key in ('outputs', 'gt_bboxes_3d', 'gt_labels_3d')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions for the first entry of `outputs`: boxes, scores and labels.\n",
    "first_pred = outputs[0]['pts_bbox']\n",
    "outputs_bbox = first_pred['boxes_3d'].tensor\n",
    "outputs_scores = first_pred['scores_3d']\n",
    "outputs_labels = first_pred['labels_3d']\n",
    "\n",
    "# Ground truth for the first entry; the `.data[0][0]` chain unwraps the stored container.\n",
    "gt_bbox = gt_bboxes_3d[0].data[0][0].tensor\n",
    "gt_label = gt_labels_3d[0].data[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "CLASSES = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle',\n",
    "            'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone',\n",
    "            'barrier')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class under inspection. The numeric index is looked up from CLASSES so the\n",
    "# name and the index cannot drift apart ('traffic_cone' is at position 8).\n",
    "# NOTE(review): assumes label ids in the loaded data follow the order of CLASSES; confirm.\n",
    "cls = 'traffic_cone'\n",
    "i = CLASSES.index(cls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground truth: keep only the boxes whose label equals the selected class index.\n",
    "cls_gt_mask = gt_label == i\n",
    "cls_gt_bbox = gt_bbox[cls_gt_mask]\n",
    "\n",
    "# Predictions: same filter, applied to both the boxes and their scores.\n",
    "cls_output_mask = outputs_labels == i\n",
    "cls_output_bbox, cls_output_conf = (\n",
    "    outputs_bbox[cls_output_mask],\n",
    "    outputs_scores[cls_output_mask],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Greedy one-to-one matching of predicted boxes to ground-truth boxes.\n",
    "# Each prediction, in the order it appears in cls_output_bbox, claims its nearest\n",
    "# not-yet-taken ground-truth box if that box is closer than the threshold.\n",
    "# The distance is the Euclidean norm over the first two box components\n",
    "# (presumably the x/y centre; confirm against the box format).\n",
    "\n",
    "# Maximum distance for a match. NOTE(review): presumably metres; confirm.\n",
    "MATCH_DIST_THRESH = 4\n",
    "\n",
    "# Indices into cls_gt_bbox that have already been claimed by a prediction.\n",
    "taken = set()\n",
    "\n",
    "# The loop deliberately does not reuse the name `i`, which holds the class index.\n",
    "for bbox in cls_output_bbox:\n",
    "\n",
    "    min_dist = np.inf\n",
    "    best_j = None  # index of the nearest free ground-truth box, if any\n",
    "    for j, this_bbox in enumerate(cls_gt_bbox):\n",
    "        if j in taken:\n",
    "            continue\n",
    "        dist = np.linalg.norm(np.array(bbox[:2]) - np.array(this_bbox[:2]))\n",
    "        if dist < min_dist:\n",
    "            min_dist = dist\n",
    "            best_j = j\n",
    "\n",
    "    # best_j stays None when there is no free ground-truth box left.\n",
    "    is_match = best_j is not None and min_dist < MATCH_DIST_THRESH\n",
    "\n",
    "    if is_match:\n",
    "        # Record the nearest box itself, not the last index the inner loop visited.\n",
    "        taken.add(best_j)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "19"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of ground-truth boxes left after filtering to the selected class.\n",
    "len(cls_gt_bbox)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of distinct ground-truth indices recorded in `taken` by the matching loop.\n",
    "len(taken)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug peek at `dist`, i.e. whatever value the matching loop assigned last.\n",
    "# NOTE(review): `dist` comes from np.linalg.norm of a 1-D difference, so it is\n",
    "# presumably a scalar and argmin() would always print 0; confirm the intent,\n",
    "# or build a vector of distances before taking argmin.\n",
    "print(dist)\n",
    "print(dist.argmin())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('bev')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "c02ad8a9d7f137533d1d53e5dfa65b2892f6efa88d9e54eaf40000517f125423"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
